import heapq
from collections import defaultdict
from functools import cache
from heapq import heapify, heapreplace
from itertools import accumulate

from .data_struct import *


def solution_876(head: Optional[ListNode]) -> Optional[ListNode]:
    """Return the middle node of a singly linked list.

    Fast/slow pointers: ``fast`` moves two nodes per step, ``slow`` one.
    When ``fast`` can no longer advance, ``slow`` is at the middle; for an
    even length this is the second of the two middle nodes. Returns ``None``
    for an empty list.
    """
    slow = fast = head
    while fast is not None and fast.next is not None:
        slow = slow.next
        fast = fast.next.next
    return slow


def solution_600(n: int) -> int:
    """Count integers in [0, n] whose binary form has no two adjacent 1 bits.

    Digit DP over ``bin(n)`` from the most significant bit. The state is the
    bit position, the previously placed bit and whether the prefix is still
    equal to n's prefix ("tight"). Only non-tight states are memoised,
    because each tight state is visited at most once.
    """
    bits = bin(n)[2:]
    width = len(bits)
    memo = {}

    def count(pos, prev, tight):
        if pos == width:
            return 1
        key = (pos, prev)
        if not tight and key in memo:
            return memo[key]
        cur = bits[pos]
        # A 0 can always be placed.
        total = count(pos + 1, 0, tight and cur == '0')
        # A 1 needs a preceding 0 and must not exceed n while tight.
        if prev == 0 and (cur == '1' or not tight):
            total += count(pos + 1, 1, tight and cur == '1')
        if not tight:
            memo[key] = total
        return total

    return count(0, 0, True)


def solution_902(digits: List[str], n: int) -> int:
    """Count positive integers <= n that use only the given digits (repetition allowed).

    Digit DP over ``str(n)``:
      * ``limit`` -- the prefix so far equals n's prefix, so the next digit is
        capped by the corresponding digit of n;
      * ``num``   -- at least one digit has been placed; otherwise the position
        may be skipped (a shorter number).

    Args:
        digits: allowed digit characters ('1'..'9'). Order and duplicates do
            not matter (they are normalised to a sorted unique list).
        n: inclusive upper bound.

    Returns:
        The number of valid integers in [1, n].
    """
    s = str(n)
    ls = len(s)
    allowed = sorted({int(d) for d in digits})
    # dp[i] caches only the state (i, limit=False, num=True); every other state
    # is visited at most once. The lookup below checks ``num`` explicitly, so the
    # cache no longer depends on the order in which states are explored.
    dp = [-1] * ls

    def dfs(i, limit, num):
        if i == ls:
            return 1 if num else 0
        if not limit and num and dp[i] != -1:
            return dp[i]
        res = 0
        if not num:
            # Skip this position: the number starts later (is shorter).
            res += dfs(i + 1, False, False)
        up = int(s[i]) if limit else 9
        for d in allowed:
            if d > up:
                break
            res += dfs(i + 1, limit and d == up, True)
        if not limit and num:
            dp[i] = res
        return res

    return dfs(0, True, False)


def solution_670(num: int) -> int:
    """Return the largest number obtainable from ``num`` with at most one digit swap.

    The digits are compared with their descending-sorted order. At the first
    position where they differ, that position receives the larger digit, taken
    from its right-most occurrence so the displaced smaller digit lands as far
    right as possible.
    """
    digits = str(num)
    order = sorted(enumerate(digits), key=lambda item: item[1], reverse=True)
    for pos, (src, best) in enumerate(order):
        if best == digits[pos]:
            continue
        # The sort is stable, so equal digits keep their original left-to-right
        # order; walking through the run of ``best`` ends at its right-most index.
        k = pos + 1
        while k < len(order) and order[k][1] == best:
            src = order[k][0]
            k += 1
        swapped = digits[:pos] + best + digits[pos + 1:src] + digits[pos] + digits[src + 1:]
        return int(swapped)
    return num


def solution_670_2(num: int) -> int:
    """Maximum swap in one right-to-left pass.

    Scanning from the right, track the index of the largest digit seen so far
    (its right-most occurrence). Every digit strictly smaller than that maximum
    is a swap candidate; the left-most candidate, swapped with the tracked
    maximum, gives the answer.
    """
    chars = list(str(num))
    best = len(chars) - 1
    left = right = -1
    for idx in reversed(range(len(chars) - 1)):
        if chars[idx] > chars[best]:
            best = idx
        elif chars[idx] < chars[best]:
            left, right = idx, best
    if left < 0:
        return num
    chars[left], chars[right] = chars[right], chars[left]
    return int(''.join(chars))


def solution_514(ring: str, key: str) -> int:
    """Minimum steps to spell ``key`` on the rotating ``ring`` (one step per press included).

    For every ring position and letter, precompute the nearest index of that
    letter strictly to the left and strictly to the right, wrapping around the
    ring. A memoised search then only considers rotating to one of those two
    nearest occurrences for each key character.
    """
    codes = [ord(ch) - ord('a') for ch in ring]
    targets = [ord(ch) - ord('a') for ch in key]
    size = len(codes)

    # left_near[i][c]: nearest index of letter c before i, wrapping around.
    # Seeding with each letter's last occurrence handles the wrap at i == 0.
    last = [0] * 26
    for idx, c in enumerate(codes):
        last[c] = idx
    left_near = []
    for idx, c in enumerate(codes):
        left_near.append(last[:])
        last[c] = idx

    # right_near[i][c]: nearest index of letter c after i, wrapping around.
    # Seeding with each letter's first occurrence handles the wrap at the end.
    first = [0] * 26
    for idx in range(size - 1, -1, -1):
        first[codes[idx]] = idx
    right_near = [None] * size
    for idx in range(size - 1, -1, -1):
        right_near[idx] = first[:]
        first[codes[idx]] = idx

    @cache
    def cost(at: int, k: int) -> int:
        """Rotation cost to spell targets[k:] starting with position ``at`` aligned."""
        if k == len(targets):
            return 0
        want = targets[k]
        if codes[at] == want:
            return cost(at, k + 1)
        lo = left_near[at][want]
        hi = right_near[at][want]
        # Cyclic distances counter-clockwise to ``lo`` and clockwise to ``hi``.
        return min(cost(lo, k + 1) + (at - lo) % size,
                   cost(hi, k + 1) + (hi - at) % size)

    return cost(0, 0) + len(targets)


def solution_514_2(ring: str, key: str) -> int:
    """Freedom Trail, bottom-up variant of the same nearest-occurrence DP.

    ``nxt[i]`` holds the minimal rotation cost to spell the remaining suffix of
    ``key`` when ring position ``i`` is aligned. Key characters are processed
    from last to first. One step per key character is added for pressing.
    """
    codes = [ord(ch) - ord('a') for ch in ring]
    targets = [ord(ch) - ord('a') for ch in key]
    size = len(codes)

    # Nearest occurrence of each letter strictly left of every position (cyclic).
    marker = [0] * 26
    for idx, c in enumerate(codes):
        marker[c] = idx
    left_near = []
    for idx, c in enumerate(codes):
        left_near.append(marker.copy())
        marker[c] = idx

    # Nearest occurrence of each letter strictly right of every position (cyclic).
    for idx in reversed(range(size)):
        marker[codes[idx]] = idx
    right_near = [None] * size
    for idx in reversed(range(size)):
        right_near[idx] = marker.copy()
        marker[codes[idx]] = idx

    m = len(targets)
    if m == 0:
        return 0
    nxt = [0] * size
    for k in range(m - 1, -1, -1):
        want = targets[k]
        cur = [0] * size
        for at in range(size):
            if codes[at] == want:
                cur[at] = nxt[at]
                continue
            lo, hi = left_near[at][want], right_near[at][want]
            cur[at] = min(nxt[lo] + (at - lo) % size,
                          nxt[hi] + (hi - at) % size)
        nxt = cur
    return nxt[0] + m


def solution_993(root: Optional[TreeNode], x: int, y: int) -> bool:
    """Return True if the nodes valued ``x`` and ``y`` are cousins.

    Cousins are at the same depth but have different parents. The tree is
    scanned level by level, recording the parent under which ``x`` / ``y``
    appear. The first level on which either value shows up decides the
    answer: both must appear there, under different parents.

    Returns ``False`` for an empty tree, and when neither value appears below
    the root (the previous version fell off the end and returned ``None``).
    """
    level = [] if root is None else [root]
    while level:
        parent_x = parent_y = None
        next_level = []
        for node in level:
            for child in (node.left, node.right):
                if child is None:
                    continue
                next_level.append(child)
                if child.val == x:
                    parent_x = node
                if child.val == y:
                    parent_y = node
        if parent_x is not None or parent_y is not None:
            return (parent_x is not None and parent_y is not None
                    and parent_x is not parent_y)
        level = next_level
    return False


def solution_987(root: Optional[TreeNode]) -> List[List[int]]:
    """Vertical order traversal of a binary tree.

    The root is at (row 0, col 0); a left child is at (row + 1, col - 1) and a
    right child at (row + 1, col + 1). The result has one list per column,
    left to right. Inside a column, values are ordered by row; values sharing
    the same (row, col) are ordered by value.

    Returns ``[]`` for an empty tree (the previous version raised IndexError).
    The walk uses an explicit stack, so deep trees do not hit Python's
    recursion limit.
    """
    cells = defaultdict(list)
    stack = [] if root is None else [(root, 0, 0)]
    while stack:
        node, row, col = stack.pop()
        cells[(row, col)].append(node.val)
        if node.left is not None:
            stack.append((node.left, row + 1, col - 1))
        if node.right is not None:
            stack.append((node.right, row + 1, col + 1))

    ans = []
    current_col = None
    for row, col in sorted(cells, key=lambda rc: (rc[1], rc[0])):
        if col != current_col:
            ans.append([])
            current_col = col
        ans[-1].extend(sorted(cells[(row, col)]))
    return ans


def solution_589(root: Optional[Node]) -> List[int]:
    """Preorder traversal (node, then children left to right) of an N-ary tree, recursively."""
    out = []

    def visit(node):
        out.append(node.val)
        for kid in node.children:
            visit(kid)

    if root:
        visit(root)
    return out


def solution_589_2(root: Optional[Node]) -> List[int]:
    """Preorder traversal of an N-ary tree with an explicit stack."""
    if not root:
        return []
    order = []
    pending = [root]
    while pending:
        node = pending.pop()
        order.append(node.val)
        # Push children right-to-left so the left-most child is popped first.
        pending.extend(reversed(node.children))
    return order


def solution_590(root: Optional[Node]) -> List[int]:
    """Postorder traversal (children left to right, then node) of an N-ary tree, recursively."""
    out = []

    def visit(node):
        for kid in node.children:
            visit(kid)
        out.append(node.val)

    if root:
        visit(root)
    return out


def solution_590_2(root: Optional[Node]) -> List[int]:
    """Postorder traversal of an N-ary tree with an explicit stack.

    Each stack entry carries an ``expanded`` flag. A node is emitted when it is
    popped the second time, after its children have been pushed (and therefore
    emitted), or immediately if it has no children.
    """
    if not root:
        return []
    out = []
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded or not node.children:
            out.append(node.val)
            continue
        stack.append((node, True))
        for kid in reversed(node.children):
            stack.append((kid, False))
    return out


def solution_889(preorder: List[int], postorder: List[int]) -> Optional[TreeNode]:
    """Rebuild a binary tree from its preorder and postorder traversals.

    The element after a root in preorder is the root of its left subtree. Its
    position in postorder gives the size of that left subtree, which splits
    both sequences into left and right parts.
    """
    post_index = {val: idx for idx, val in enumerate(postorder)}

    def build(pre_start, post_start, size):
        # The preorder and postorder slices always have the same length ``size``.
        if size <= 0:
            return None
        node = TreeNode(preorder[pre_start])
        if size == 1:
            return node
        left_size = post_index[preorder[pre_start + 1]] - post_start + 1
        node.left = build(pre_start + 1, post_start, left_size)
        node.right = build(pre_start + 1 + left_size, post_start + left_size,
                           size - 1 - left_size)
        return node

    return build(0, 0, len(postorder))


def solution_938(root: Optional[TreeNode], low: int, high: int) -> int:
    """Sum the values of BST nodes lying in the inclusive range [low, high].

    BST ordering prunes whole subtrees. If a node is above ``high``, only its
    left subtree can contribute; if it is below ``low``, only its right subtree
    can.
    """
    total = 0
    stack = [root]
    while stack:
        node = stack.pop()
        if not node:
            continue
        val = node.val
        if val > high:
            stack.append(node.left)
        elif val < low:
            stack.append(node.right)
        else:
            total += val
            stack.append(node.left)
            stack.append(node.right)
    return total


def solution_665(nums: List[int]) -> bool:
    """Return True if ``nums`` can become non-decreasing by changing at most one element.

    Greedy: at the first descent ``arr[i] > arr[i + 1]``, lowering ``arr[i]``
    is preferred because it keeps later comparisons easiest. If that would
    drop it below ``arr[i - 1]``, ``arr[i + 1]`` is raised to ``arr[i]``
    instead. A second descent means one change is not enough.

    The caller's list is not modified; the greedy works on a private copy.
    (Previously ``nums`` itself was overwritten.)
    """
    arr = list(nums)
    changes = 0
    for i in range(len(arr) - 1):
        if arr[i] > arr[i + 1]:
            changes += 1
            if changes > 1:
                return False
            if i > 0 and arr[i + 1] < arr[i - 1]:
                # Lowering arr[i] would break order with arr[i - 1]; raise arr[i + 1].
                arr[i + 1] = arr[i]
    return True


def solution_518(amount: int, coins: List[int]) -> int:
    """Number of coin combinations summing to ``amount``, using a full 2-D table.

    ``ways[i][v]`` counts combinations of value ``v`` using only the first
    ``i`` denominations, each available in unlimited supply.
    """
    rows = len(coins) + 1
    ways = [[0] * (amount + 1) for _ in range(rows)]
    ways[0][0] = 1
    for i in range(1, rows):
        coin = coins[i - 1]
        prev, cur = ways[i - 1], ways[i]
        for v in range(amount + 1):
            # Either never use this coin, or use it once more on top of cur[v - coin].
            cur[v] = prev[v] + (cur[v - coin] if v >= coin else 0)
    return ways[rows - 1][amount]


def solution_518_2(amount: int, coins: List[int]) -> int:
    """Number of coin combinations summing to ``amount``, with two rolling rows.

    ``prev_row`` is the table row for the denominations processed so far;
    ``cur_row`` is rebuilt for the next denomination, and then the roles swap.
    """
    prev_row = [0] * (amount + 1)
    cur_row = [0] * (amount + 1)
    prev_row[0] = 1
    for coin in coins:
        for v in range(amount + 1):
            take = cur_row[v - coin] if v - coin >= 0 else 0
            cur_row[v] = prev_row[v] + take
        prev_row, cur_row = cur_row, prev_row
    return prev_row[amount]


def solution_704(nums: List[int], target: int) -> int:
    """Binary search in an ascending list; return the index of ``target`` or -1.

    Uses the open interval (lo, hi). The answer, if present, always lies
    strictly between the bounds.
    """
    lo, hi = -1, len(nums)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        probe = nums[mid]
        if probe < target:
            lo = mid
        elif probe > target:
            hi = mid
        else:
            return mid
    return -1


def solution_894(n: int) -> List[Optional[TreeNode]]:
    """All full binary trees with ``n`` nodes, all node values 0.

    A full tree on ``count`` nodes is a root plus a left subtree with an odd
    number of nodes and a right subtree with the remainder. Results are
    memoised per size, so subtrees are shared between the returned trees.
    """
    @cache
    def build(count):
        if count == 1:
            return [TreeNode(0)]
        if count == 0:
            return []
        remaining = count - 1
        trees = []
        for left_count in range(1, remaining + 1, 2):
            for left_tree in build(left_count):
                for right_tree in build(remaining - left_count):
                    trees.append(TreeNode(0, left_tree, right_tree))
        return trees

    return build(n)


def solution_707():
    """Placeholder for problem 707.

    The implementation is the ``MyLinkedList`` class in ``data_struct``; this
    function does nothing and returns ``None``.
    """
    return None


def solution_705():
    """Placeholder for problem 705.

    The implementation is the ``MyHashSet`` class in ``data_struct``; this
    function does nothing and returns ``None``.
    """
    return None


def solution_706():
    """Placeholder for problem 706.

    The implementation is the ``MyHashMap`` class in ``data_struct``; this
    function does nothing and returns ``None``.
    """
    return None


def solution_924(graph: List[List[int]], initial: List[int]) -> int:
    """Choose the initially infected node whose removal saves the most nodes.

    Nodes are grouped into connected components with union-find. Removing an
    infected node helps only when it is the sole infected node in its
    component; then the whole component is saved. The node in the largest
    such component wins, with ties going to the smallest index. If no node
    qualifies, the smallest index in ``initial`` is returned.

    ``initial`` is not modified (the previous version sorted it in place).
    """
    n = len(graph)
    fa = list(range(n))

    def find(x: int) -> int:
        # Iterative path halving: no recursion on long parent chains.
        while fa[x] != x:
            fa[x] = fa[fa[x]]
            x = fa[x]
        return x

    for i, row in enumerate(graph):
        for j, b in enumerate(row):
            if b:
                ri, rj = find(i), find(j)
                if ri != rj:
                    fa[ri] = rj

    comp_size = defaultdict(int)
    for x in range(n):
        comp_size[find(x)] += 1
    infected = defaultdict(int)
    for node in initial:
        infected[find(node)] += 1

    ordered = sorted(initial)
    best_node, best_saved = ordered[0], 0
    for node in ordered:  # ascending, so strict ">" keeps the smallest index on ties
        r = find(node)
        if infected[r] == 1 and comp_size[r] > best_saved:
            best_saved = comp_size[r]
            best_node = node
    return best_node


def solution_928(graph: List[List[int]], initial: List[int]) -> int:
    """Remove one infected node (with its edges) to minimise the final infection.

    Each component of clean nodes is explored without entering infected
    nodes, recording which infected nodes border it. If exactly one infected
    node borders the component, removing that node saves the whole
    component. The node saving the most wins, with ties going to the smaller
    index; if none saves anything, the smallest initial node is returned.
    """
    infected = set(initial)
    n = len(graph)
    visited = [False] * n
    saved = Counter()
    for start in range(n):
        if visited[start] or start in infected:
            continue
        visited[start] = True
        stack = [start]
        comp_size = 0
        # -1: no infected neighbour seen yet; -2: two or more distinct ones.
        source = -1
        while stack:
            u = stack.pop()
            comp_size += 1
            for v, edge in enumerate(graph[u]):
                if edge == 0:
                    continue
                if v in infected:
                    if source == -1:
                        source = v
                    elif source != v:
                        source = -2
                elif not visited[v]:
                    visited[v] = True
                    stack.append(v)
        if source >= 0:
            saved[source] += comp_size
    if not saved:
        return min(initial)
    # Larger saving first, then smaller node id.
    return min((-size, node) for node, size in saved.items())[1]


def solution_857(quality: List[int], wage: List[int], k: int) -> float:
    """Minimum cost to hire exactly ``k`` workers under ratio-proportional pay.

    Workers are processed by ascending wage/quality ratio. When a worker is
    the highest-ratio member of the group, every member is paid at that
    ratio, so the cost is ratio * (sum of qualities). A max-heap keeps the
    ``k`` smallest qualities seen so far.
    """
    workers = sorted(zip(quality, wage), key=lambda qw: qw[1] / qw[0])
    heap = [-q for q, _ in workers[:k]]  # negated qualities act as a max-heap
    heapq.heapify(heap)
    total_q = -sum(heap)
    best = total_q * workers[k - 1][1] / workers[k - 1][0]
    for q, w in workers[k:]:
        if q >= -heap[0]:
            continue
        # Swap out the largest quality for q; heapreplace returns the old (negated) top.
        total_q += heapq.heapreplace(heap, -q) + q
        best = min(best, total_q * w / q)
    return best


def solution_741(grid: List[List[int]]) -> int:
    """Cherry Pickup: maximum cherries collected on a round trip through ``grid``.

    The round trip is modelled as two people walking together from (0, 0) to
    (n-1, n-1), each moving one cell right or down per step. When both stand
    on the same cell its value is counted only once. We may require that one
    person always stays on one side of the other (here: A's column <= B's
    column, ``j <= k``), which roughly halves the number of states.

    Original recurrence
        t  number of steps taken
        j  column (second grid index) of person A; A's first index is t - j
        k  column (second grid index) of person B; B's first index is t - k
        f[t][j][k] = max{   f[t-1][j  ][k  ],
                            f[t-1][j  ][k-1],
                            f[t-1][j-1][k  ],
                            f[t-1][j-1][k-1], } + val

    To avoid out-of-bounds indices (j-1 or k-1 == -1), both indices are
    shifted by one: j -> j+1, k -> k+1

    f[t][j+1][k+1] = max{   f[t-1][j+1][k+1],
                            f[t-1][j+1][k  ],
                            f[t-1][j  ][k+1],
                            f[t-1][j  ][k  ], } + val

    val = grid[t-j][j] + grid[t-k][k] if j!=k else grid[t-j][j]

    Initial values: every entry of f is -inf (unreachable)
    f[0][1][1] = grid[0][0]

    Range of j
    i+j = t
    0<=i<=n-1
    0<=j<=n-1

    max(t-(n-1),0) <= j <= min(t,n-1)

    Returns:
        The maximum number of cherries, or 0 when (n-1, n-1) is unreachable
        (its state is still -inf). ``inf`` is not defined in this block;
        presumably it is float('inf') from the star import -- TODO confirm.
    """
    n = len(grid)
    # f[t][j + 1][k + 1]: best total after t steps with A in column j and B in
    # column k (indices shifted by one so index 0 acts as an out-of-range sentinel).
    f = [[[-inf] * (n + 1) for _ in range(n + 1)] for _ in range(n * 2 - 1)]
    f[0][1][1] = grid[0][0]
    for t in range(1, n * 2 - 1):
        for j in range(max(t - n + 1, 0), min(t + 1, n)):
            # Negative cells block the path: such states stay -inf.
            if grid[t - j][j] < 0:
                continue
            # k starts at j: only states with A's column <= B's column are filled.
            for k in range(j, min(t + 1, n)):
                if grid[t - k][k] < 0:
                    continue
                # When j == k, f[t - 1][j + 1][k] has A's column > B's column and is
                # never written (stays -inf); its mirror image is f[t - 1][j][k + 1].
                res1 = max(f[t - 1][j + 1][k + 1], f[t - 1][j + 1][k], f[t - 1][j][k + 1], f[t - 1][j][k])
                # A shared cell is counted once.
                res2 = grid[t - j][j] + grid[t - k][k] if j != k else grid[t - j][j]
                f[t][j + 1][k + 1] = res1 + res2
    # -inf (no valid round trip) is reported as 0.
    return max(f[-1][n][n], 0)


def solution_994(grid: List[List[int]]) -> int:
    """Minutes until no fresh orange remains, or -1 if some can never rot.

    Cell values: 0 empty, 1 fresh, 2 rotten. Each minute, every rotten orange
    rots its fresh 4-neighbours. A multi-source BFS from all initially rotten
    oranges processes one level per minute, in O(m * n). The previous version
    copied and rescanned the whole grid every minute, O((m * n)^2).

    ``grid`` is not modified; the BFS works on a copy.
    """
    m = len(grid)
    n = len(grid[0])
    cells = [row[:] for row in grid]
    fresh = 0
    frontier = []
    for i in range(m):
        for j in range(n):
            if cells[i][j] == 1:
                fresh += 1
            elif cells[i][j] == 2:
                frontier.append((i, j))

    minutes = 0
    while fresh and frontier:
        minutes += 1
        next_frontier = []
        for i, j in frontier:
            for x, y in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)):
                if 0 <= x < m and 0 <= y < n and cells[x][y] == 1:
                    cells[x][y] = 2
                    fresh -= 1
                    next_frontier.append((x, y))
        frontier = next_frontier
    return minutes if fresh == 0 else -1


def solution_826(difficulty: List[int], profit: List[int], worker: List[int]) -> int:
    """Maximum total profit when each worker does at most one job within their ability.

    Jobs are sorted by difficulty and workers by ability. A single pointer
    sweeps through the jobs, tracking the best profit among jobs each worker
    can do. Jobs may be assigned to several workers.

    ``worker`` is not modified (the previous version sorted it in place).
    """
    jobs = sorted(zip(difficulty, profit))
    ans = j = best = 0
    for ability in sorted(worker):
        # Admit every job this worker can handle; later workers are at least as able.
        while j < len(jobs) and jobs[j][0] <= ability:
            best = max(best, jobs[j][1])
            j += 1
        ans += best
    return ans


def solution_575(candyType: List[int]) -> int:
    """Most distinct candy types obtainable when eating exactly half of the candies."""
    kinds = len(set(candyType))
    allowed = len(candyType) // 2
    return kinds if kinds < allowed else allowed


def solution_881(people: List[int], limit: int) -> int:
    """Minimum boats, each carrying at most two people with total weight <= ``limit``.

    Greedy two pointers over the sorted weights. The heaviest remaining
    person always boards, taking the lightest remaining person along if both
    fit.

    ``people`` is not modified (the previous version sorted it in place).
    """
    weights = sorted(people)
    lo, hi = 0, len(weights) - 1
    boats = 0
    while lo <= hi:
        if weights[lo] + weights[hi] <= limit:
            lo += 1
        hi -= 1
        boats += 1
    return boats


def solution_521(a: str, b: str) -> int:
    """Length of the longest uncommon subsequence of two strings, or -1 if none.

    If the strings differ, the longer one (either one when lengths are equal)
    cannot be a subsequence of the other. If they are equal, no uncommon
    subsequence exists.
    """
    if a == b:
        return -1
    return len(a) if len(a) >= len(b) else len(b)


def solution_522(strs: List[str]) -> int:
    """Length of the longest uncommon subsequence among ``strs``, or -1 if none.

    An uncommon subsequence can always be taken to be one of the whole
    strings. Candidates are tried from longest to shortest; the first one
    that is not a subsequence of any other string is the answer.

    ``strs`` is not modified (the previous version sorted it in place). The
    subsequence test now handles an empty ``s``; the previous version raised
    IndexError or wrongly returned False for it.
    """
    def is_sub(s: str, t: str) -> bool:
        # True if s is a subsequence of t; ``ch in it`` consumes the iterator.
        if len(s) > len(t):
            return False
        it = iter(t)
        return all(ch in it for ch in s)

    ordered = sorted(strs, key=len, reverse=True)
    for i, s in enumerate(ordered):
        if all(j == i or not is_sub(s, t) for j, t in enumerate(ordered)):
            return len(s)
    return -1


def solution_868(n: int) -> int:
    """Longest distance between two consecutive 1 bits of ``n`` (0 if fewer than two).

    Non-positive inputs yield 0.
    """
    best = 0
    prev = None  # bit position of the previous 1 bit
    pos = 0
    while n > 0:
        if n & 1:
            if prev is not None:
                best = max(best, pos - prev)
            prev = pos
        n >>= 1
        pos += 1
    return best


def solution_693(n: int) -> bool:
    """Return True if adjacent bits of ``n`` always differ (e.g. 0b1010).

    ``n ^ (n >> 1)`` has a 1 wherever two neighbouring bits differ, so for
    alternating bits it is a contiguous run of ones, i.e. of the form 2**k - 1.
    Values <= 1 are trivially alternating.
    """
    if n <= 1:
        return True
    mixed = n ^ (n >> 1)
    return (mixed & (mixed + 1)) == 0


def solution_520(word: str) -> bool:
    """Check capital usage: all upper, all lower, or only the first letter upper.

    A character counts as upper case when it sorts before 'a' (this matches
    ASCII letters). A word starting lower case must contain no upper case at
    all. A word starting upper case must have either only that one capital
    or capitals everywhere.
    """
    first_upper = word[0] < 'a'
    uppers = sum(1 for ch in word if ch < 'a')
    if not first_upper:
        return uppers == 0
    return uppers == 1 or uppers == len(word)


def solution_503(nums: List[int]) -> List[int]:
    """Next strictly greater element for each position of a circular array (-1 if none).

    Uses a monotonic stack over two passes of the array, which simulates the
    wrap-around. Indices are pushed only during the first pass.
    """
    size = len(nums)
    result = [-1] * size
    waiting = []  # indices still waiting for a greater value
    for step in range(2 * size):
        value = nums[step % size]
        while waiting and nums[waiting[-1]] < value:
            result[waiting.pop()] = value
        if step < size:
            waiting.append(step)
    return result


def solution_898(arr: List[int]) -> int:
    """Number of distinct values of OR over all non-empty contiguous subarrays.

    ``ending_here`` holds the OR values of every subarray ending at the
    current element. It stays small because each OR can only gain bits.
    """
    seen = set()
    ending_here = set()
    for x in arr:
        ending_here = {x | prev for prev in ending_here} | {x}
        seen.update(ending_here)
    return len(seen)


def solution_682(operations: List[str]) -> int:
    """Baseball game scoring.

    '+' records the sum of the last two scores, 'D' records double the last
    score, 'C' removes the last score, and anything else is parsed as an
    integer score. Returns the sum of the recorded scores.
    """
    record = []
    for op in operations:
        if op == 'C':
            record.pop()
        elif op == 'D':
            record.append(2 * record[-1])
        elif op == '+':
            record.append(record[-2] + record[-1])
        else:
            record.append(int(op))
    return sum(record)


def solution_699(positions: List[List[int]]) -> List[int]:
    """Height of the tallest stack after each square is dropped.

    ``positions[i] = [left, side]`` drops a square of width ``side`` over
    [left, left + side - 1]. It lands on the highest earlier square whose
    span overlaps it, in O(n^2). The answer is the running maximum of the
    landing heights.
    """
    tops = []
    for idx, (left, side) in enumerate(positions):
        right = left + side - 1
        base = 0
        for (other_left, other_side), other_top in zip(positions[:idx], tops):
            other_right = other_left + other_side - 1
            if right >= other_left and other_right >= left:
                base = max(base, other_top)
        tops.append(base + side)
    return list(accumulate(tops, max))


def solution_724(nums: List[int]) -> int:
    """Left-most index whose left-side sum equals its right-side sum, or -1."""
    total = sum(nums)
    left = 0
    for i, x in enumerate(nums):
        # The right-side sum is whatever remains after the left part and nums[i].
        if left == total - left - x:
            return i
        left += x
    return -1


def solution_721(accounts: List[List[str]]) -> List[List[str]]:
    """Merge accounts that share at least one email address.

    Account indices are joined with union-find whenever an email was already
    owned by an earlier account. Each merged group is reported as
    ``[name of the root account, *sorted unique emails]``, in order of first
    appearance of the group's root.
    """
    parent = list(range(len(accounts)))

    def root_of(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    first_owner = {}
    for idx, acc in enumerate(accounts):
        for email in acc[1:]:
            owner = first_owner.setdefault(email, idx)
            if owner != idx:
                parent[root_of(idx)] = root_of(owner)

    merged = defaultdict(set)
    for idx, acc in enumerate(accounts):
        merged[root_of(idx)].update(acc[1:])
    return [[accounts[r][0], *sorted(mails)] for r, mails in merged.items()]


def solution_807(grid: List[List[int]]) -> int:
    """Total height increase that keeps both skylines of an n x n city unchanged.

    Each building may rise to ``min(row max, column max)`` of its row and
    column.
    """
    n = len(grid)
    columns = list(zip(*grid))
    row_top = [max(grid[i]) for i in range(n)]
    col_top = [max(columns[j]) for j in range(n)]
    return sum(min(row_top[i], col_top[j]) - grid[i][j]
               for i in range(n) for j in range(n))


def solution_911():
    """Placeholder for problem 911.

    The implementation is the ``TopVotedCandidate`` class in
    ``data_struct.py``; this function does nothing and returns ``None``.
    """
    return None


def solution_955(strs: List[str]) -> int:
    """Minimum number of columns to delete so the strings end up in lexicographic order.

    ``settled[i]`` records that ``strs[i] < strs[i + 1]`` is already decided
    by a kept column. A column must be deleted if it puts an unsettled pair
    out of order; otherwise it is kept, and it settles every pair it orders
    strictly.
    """
    n = len(strs)
    settled = [False] * (n - 1)
    deleted = 0
    for column in zip(*strs):  # the i-th characters of all strings
        if any(not settled[i] and column[i] > column[i + 1] for i in range(n - 1)):
            deleted += 1
            continue
        for i in range(n - 1):
            if column[i] < column[i + 1]:
                settled[i] = True
    return deleted


def solution_572(root: Optional[TreeNode], subRoot: Optional[TreeNode]) -> bool:
    """Return True if some subtree of ``root`` is identical to ``subRoot``.

    Every node of ``root`` is tried as the top of a structural and value
    match. An empty ``root`` now returns ``subRoot is None`` instead of raising
    AttributeError. Other results, including a non-empty ``root`` with
    ``subRoot is None`` (False), are unchanged.
    """
    def same(a, b):
        # Both empty -> equal; exactly one empty -> different.
        if a is None or b is None:
            return a is b
        return a.val == b.val and same(a.left, b.left) and same(a.right, b.right)

    if root is None:
        return subRoot is None
    if same(root, subRoot):
        return True
    # Recurse only into existing children, as before.
    return bool((root.left is not None and solution_572(root.left, subRoot))
                or (root.right is not None and solution_572(root.right, subRoot)))


def solution_572_2(root: Optional[TreeNode], subRoot: Optional[TreeNode]) -> bool:
    """Subtree check that only compares nodes whose height equals ``subRoot``'s height.

    A post-order walk returns ``(height, found)`` for every node. Once a
    match is found, it propagates upward without further comparisons.
    """
    def height(node):
        if not node:
            return 0
        return 1 + max(height(node.left), height(node.right))

    def identical(a, b):
        if not a or not b:
            return a is b
        return a.val == b.val and identical(a.left, b.left) and identical(a.right, b.right)

    target_h = height(subRoot)

    def search(node):
        if not node:
            return 0, False
        lh, lfound = search(node.left)
        rh, rfound = search(node.right)
        if lfound or rfound:
            return 0, True
        h = 1 + max(lh, rh)
        return h, h == target_h and identical(node, subRoot)

    return search(root)[1]


def solution_600_2(n: int) -> int:
    """Count integers in [0, n] without adjacent 1 bits, reading bits arithmetically.

    The DP runs over bit positions of ``n`` from the most significant bit.
    ``prev_one`` is the previously placed bit and ``tight`` says whether the
    prefix still equals n's prefix.
    """
    width = n.bit_length()

    @cache
    def go(pos, prev_one, tight):
        if pos == width:
            return 1
        cap = (n >> (width - pos - 1)) & 1 if tight else 1
        # Place a 0.
        total = go(pos + 1, 0, tight and cap == 0)
        # Place a 1 when allowed by the cap and not following another 1.
        if cap == 1 and prev_one != 1:
            total += go(pos + 1, 1, tight)
        return total

    return go(0, 0, True)


def solution_676():
    """Placeholder for problem 676.

    The implementation is the ``MagicDictionary`` class, referenced as
    ``data_struct.cc#MagicDictionary``; this function does nothing and
    returns ``None``.
    """
    return None


def solution_551(s: str) -> bool:
    """Attendance award check: fewer than 2 'A' in total and never 3 consecutive 'L'.

    'P' and 'A' both end a run of 'L'; other characters are ignored.
    """
    absences = 0
    late_run = 0
    for ch in s:
        if ch == 'A':
            absences += 1
            late_run = 0
        elif ch == 'L':
            late_run += 1
        elif ch == 'P':
            late_run = 0
        if absences >= 2 or late_run >= 3:
            return False
    return True


def solution_552(n: int) -> int:
    """Number of attendance records of length ``n`` eligible for an award, mod 1e9+7.

    A record over {'A', 'L', 'P'} is eligible with fewer than 2 'A' and no 3
    consecutive 'L'. ``dp[a][l]`` counts valid prefixes with ``a`` absences
    and ``l`` trailing lates; each character extends these 6 states. This
    takes O(n) time and O(1) memory.

    The previous version recomputed every length up to 100000 on each call,
    with a fresh nested cache (~600k recursive calls per call), and failed
    for n > 100000. For n == 0 this returns 1 (the empty record); the previous
    version returned 0 because ``pre[0]`` was never filled.
    """
    MOD = 10 ** 9 + 7
    dp = [[0] * 3 for _ in range(2)]
    dp[0][0] = 1
    for _ in range(n):
        nxt = [[0] * 3 for _ in range(2)]
        for a in range(2):
            for late in range(3):
                ways = dp[a][late]
                if not ways:
                    continue
                # + 'P': resets the late run.
                nxt[a][0] = (nxt[a][0] + ways) % MOD
                # + 'A': only if no absence yet; also resets the late run.
                if a == 0:
                    nxt[1][0] = (nxt[1][0] + ways) % MOD
                # + 'L': only if fewer than 2 trailing lates.
                if late < 2:
                    nxt[a][late + 1] = (nxt[a][late + 1] + ways) % MOD
        dp = nxt
    return sum(map(sum, dp)) % MOD


def solution_698(nums: List[int], k: int) -> bool:
    """Return True if ``nums`` can be split into ``k`` subsets with equal sums.

    Bitmask search over the remaining elements. Buckets are filled one at a
    time: ``filled`` is the sum in the bucket being filled, reduced modulo
    ``target`` once a bucket completes. Values are sorted ascending, so the
    scan over remaining bits can stop at the first value that overflows the
    bucket.

    ``nums`` is not modified (the previous version sorted it in place).
    Assumes positive values, as the problem guarantees.
    """
    total = sum(nums)
    if total % k:
        return False
    target = total // k
    vals = sorted(nums)
    if vals[-1] > target:
        return False

    @cache
    def can(mask, filled):
        if mask == 0 and filled == 0:
            return True
        rest = mask
        while rest:
            low = rest & -rest  # lowest remaining bit = smallest remaining value
            nxt = vals[low.bit_length() - 1] + filled
            if nxt > target:
                break  # every later value is at least as large
            if can(mask & ~low, nxt % target):
                return True
            rest &= rest - 1
        return False

    return can((1 << len(vals)) - 1, 0)


def solution_690(employees: List['Employee'], id: int) -> int:
    """Total importance of employee ``id`` and all of their direct and indirect subordinates.

    Each employee object is expected to expose ``id``, ``importance`` and
    ``subordinates`` (a list of ids). Only the subtree of ``id`` is walked,
    iteratively. The previous version recursed over the whole organisation
    from every root just to answer one id.

    Returns 0 when ``id`` is not among ``employees``, as before.
    """
    by_id = {e.id: e for e in employees}
    if id not in by_id:
        return 0
    total = 0
    stack = [id]
    while stack:
        emp = by_id[stack.pop()]
        total += emp.importance
        stack.extend(emp.subordinates)
    return total


def solution_977(nums: List[int]) -> List[int]:
    """Squares of a non-decreasing list, in non-decreasing order (stack-based merge).

    Negative values are stacked as absolute values, so the smallest
    magnitude ends up on top. Each non-negative value first flushes the
    stacked magnitudes that do not exceed it.
    """
    pending = []
    out = []
    for value in nums:
        if value < 0:
            pending.append(-value)
            continue
        while pending and pending[-1] <= value:
            out.append(pending.pop() ** 2)
        out.append(value ** 2)
    out.extend(v ** 2 for v in reversed(pending))
    return out


def solution_977_2(nums: List[int]) -> List[int]:
    """Squares of a non-decreasing list, in non-decreasing order (two pointers).

    The largest square always comes from one of the two ends, so the output
    is filled from the back.
    """
    size = len(nums)
    out = [0] * size
    lo, hi = 0, size - 1
    for slot in reversed(range(size)):
        a, b = nums[lo], nums[hi]
        if -a > b:
            out[slot] = a * a
            lo += 1
        else:
            out[slot] = b * b
            hi -= 1
    return out


def solution_815(routes: List[List[int]], source: int, target: int) -> int:
    """Fewest buses needed to travel from ``source`` to ``target``, or -1.

    BFS where each level is a set of buses. Boarding one more bus costs one
    ride. Every bus is enqueued at most once, including the buses serving
    ``source``; the previous version did not mark those as used up front, so
    they could be enqueued and scanned again. Each stop's buses are expanded
    only once.
    """
    if source == target:
        return 0
    stop_to_buses = defaultdict(list)
    for bus, route in enumerate(routes):
        for stop in route:
            stop_to_buses[stop].append(bus)

    used_bus = set(stop_to_buses.get(source, ()))
    frontier = list(used_bus)
    seen_stop = {source}
    rides = 0
    while frontier:
        rides += 1
        next_frontier = []
        for bus in frontier:
            for stop in routes[bus]:
                if stop == target:
                    return rides
                if stop in seen_stop:
                    continue
                seen_stop.add(stop)
                for other in stop_to_buses[stop]:
                    if other not in used_bus:
                        used_bus.add(other)
                        next_frontier.append(other)
        frontier = next_frontier
    return -1


def solution_983(days: List[int], costs: List[int]) -> int:
    """Minimum cost of 1-, 7- and 30-day passes covering every travel day.

    Bottom-up over the travel days: ``best[i]`` is the cheapest way to cover
    days[i:]. A pass bought on days[i] covers every following day within its
    window; the search resumes at the first day outside that window.
    """
    n = len(days)
    best = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        start = days[i]
        after_week = i + 1
        while after_week < n and days[after_week] - start + 1 <= 7:
            after_week += 1
        after_month = after_week
        while after_month < n and days[after_month] - start + 1 <= 30:
            after_month += 1
        best[i] = min(best[i + 1] + costs[0],
                      best[after_week] + costs[1],
                      best[after_month] + costs[2])
    return best[0]


def solution_871(target: int, startFuel: int, stations: List[List[int]]) -> int:
    """Minimum number of refuelling stops to reach ``target``, or -1 if impossible.

    Greedy with a max-heap: pass stations for free, remembering their fuel.
    Whenever the next station (or the target) is out of reach, retroactively
    refuel at the largest remembered station.

    Uses ``heapq.heappush`` / ``heapq.heappop`` from the imported ``heapq``
    module. The previous version called bare ``heappush``/``heappop``, which
    this module never imports (only ``heapify`` and ``heapreplace`` are).
    """
    heap = []  # negated fuel amounts of stations already passed
    reach = startFuel
    stops = 0
    for pos, fuel in stations:
        while heap and reach < pos:
            reach -= heapq.heappop(heap)
            stops += 1
        if reach < pos:
            return -1
        heapq.heappush(heap, -fuel)
    while heap and reach < target:
        reach -= heapq.heappop(heap)
        stops += 1
    return stops if reach >= target else -1


def solution_908(nums: List[int], k: int) -> int:
    """Smallest possible max-min spread after moving every element by at most ``k``.

    Each end can move ``k`` toward the other, so the spread shrinks by up to
    ``2 * k`` and never drops below zero.
    """
    spread = max(nums) - min(nums) - 2 * k
    return max(spread, 0)


def solution_910(nums: List[int], k: int) -> int:
    """Smallest max-min spread after adding either +k or -k to every element.

    In sorted order, some prefix gets +k and the rest gets -k. For a split
    after index i, the new maximum is max(vals[i] + k, vals[-1] - k) and the
    new minimum is min(vals[i + 1] - k, vals[0] + k).

    ``nums`` is not modified (the previous version sorted it in place).
    """
    vals = sorted(nums)
    lo, hi = vals[0], vals[-1]
    best = hi - lo  # everyone shifted the same way
    for i in range(len(vals) - 1):
        top = max(vals[i] + k, hi - k)
        bottom = min(vals[i + 1] - k, lo + k)
        best = min(best, top - bottom)
    return best


def solution_685(edges: List[List[int]]) -> List[int]:
    """Edge to remove so that ``edges`` forms a rooted tree on nodes 1..n.

    The input is a rooted tree plus one extra directed edge. There are two
    cases:

    * Some node has no parent. Then the extra edge gave a node ``y`` a second
      parent. Drop the parent edge of ``y`` that lies on a cycle, i.e. whose
      parent is reachable from ``y``; otherwise drop the later of the two.
    * Every node has exactly one parent. Then there is one directed cycle.
      Nodes without children are peeled off repeatedly until only cycle nodes
      remain, and the last edge in ``edges`` joining two of them is returned.

    The reachability test is an iterative DFS. The previous recursive one
    could exceed Python's recursion limit on a chain of ~1000 nodes. Returns
    ``[]`` instead of falling off the end with ``None`` if the input does not
    match the expected shape.
    """
    n = len(edges)
    parents = [[] for _ in range(n + 1)]
    children = [set() for _ in range(n + 1)]
    for x, y in edges:
        parents[y].append(x)
        children[x].add(y)

    def reachable(start: int, goal: int) -> bool:
        # Explicit stack instead of recursion; ``seen`` keeps the walk finite.
        stack = [start]
        seen = {start}
        while stack:
            node = stack.pop()
            for nxt in children[node]:
                if nxt == goal:
                    return True
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return False

    if any(not parents[x] for x in range(1, n + 1)):
        for y in range(1, n + 1):
            if len(parents[y]) == 2:
                first, second = parents[y]
                return [first, y] if reachable(y, first) else [second, y]

    # No root: trim childless nodes level by level; what survives is the cycle.
    on_cycle = set(range(1, n + 1))
    leaves = [x for x in range(1, n + 1) if not children[x]]
    on_cycle.difference_update(leaves)
    while leaves:
        next_leaves = []
        for y in leaves:
            x = parents[y][0]
            children[x].remove(y)
            if not children[x]:
                next_leaves.append(x)
                on_cycle.discard(x)
        leaves = next_leaves
    for x, y in reversed(edges):
        if x in on_cycle and y in on_cycle:
            return [x, y]
    return []


def solution_685_2(edges: List[List[int]]) -> List[int]:
    """Redundant Connection II via union-find, in a single pass over ``edges``.

    Records the index of the edge that gives a node a second parent
    (``conflict_idx``; that edge is not added) and the index of the edge that
    closes an undirected cycle (``cycle_idx``). Then:

    * only a cycle -> remove the cycle-closing edge;
    * a conflict and also a cycle -> the conflicting node's first parent edge
      is the culprit;
    * only a conflict -> remove the conflicting edge.
    """
    n = len(edges)
    uf = UnionFind(n + 1)
    parent = list(range(n + 1))
    conflict_idx = cycle_idx = -1
    for idx, (u, v) in enumerate(edges):
        if parent[v] != v:
            # v already has a parent: this edge is the second one into v.
            conflict_idx = idx
            continue
        parent[v] = u
        if uf.find(u) == uf.find(v):
            cycle_idx = idx
        else:
            uf.union(u, v)

    if conflict_idx < 0:
        u, v = edges[cycle_idx]
        return [u, v]
    u, v = edges[conflict_idx]
    if cycle_idx >= 0:
        return [parent[v], v]
    return [u, v]


def solution_632(nums: List[List[int]]) -> List[int]:
    """Smallest range [lo, hi] containing at least one element from each sorted list.

    A min-heap holds one "current" element per list, and ``hi`` tracks their
    maximum. The minimum is repeatedly advanced to the next element of its
    list until some list is exhausted. Only strictly smaller ranges replace
    the best, so the earliest (and therefore left-most) range wins ties.
    """
    heap = [(lst[0], row, 0) for row, lst in enumerate(nums)]
    heapify(heap)
    best_lo = heap[0][0]
    best_hi = hi = max(lst[0] for lst in nums)
    while True:
        _, row, col = heap[0]
        if col + 1 >= len(nums[row]):
            break
        value = nums[row][col + 1]
        heapreplace(heap, (value, row, col + 1))
        if value > hi:
            hi = value
        lo = heap[0][0]
        if hi - lo < best_hi - best_lo:
            best_lo, best_hi = lo, hi
    return [best_lo, best_hi]
